#!/usr/bin/python
# -*- coding:utf-8 -*-
#
# Python driver for GDEM035T81 3.5" e-Paper display
# Controller: SSD1685
# Resolution: 384x184 (landscape); addressed by this driver as 184 wide x 384 high
#
# Based on GxEPD2 driver by Jean-Marc Zingg:
# https://github.com/ZinggJM/GxEPD2/blob/master/src/gdey/GxEPD2_290_GDEY029T71H.cpp
# Adapted for Waveshare HAT on Raspberry Pi (spidev + RPi.GPIO)

import RPi.GPIO as GPIO
import spidev
import time

# GPIO numbers of the control lines (standard Waveshare HAT wiring)
RST_PIN = 17
DC_PIN = 25
CS_PIN = 8
BUSY_PIN = 24

# Panel size in pixels, as addressed by the driver
WIDTH = 184
HEIGHT = 384

# The SSD1685 OTP is programmed for a wider panel, so every X address is
# moved by this many pixels to line up with the real TFT. Tweak the value
# if the picture appears shifted horizontally.
SOURCE_SHIFT = 8


class EPD:
    """Driver for the SSD1685-based e-Paper panel on a Raspberry Pi.

    The controller is driven over SPI (bus 0, device 0) plus four GPIO lines
    (reset, data/command, chip select, busy).  Frame buffers are flat lists
    of bytes, one bit per pixel, MSB first, ``(width + 7) // 8`` bytes per
    row; a set bit is white and a cleared bit is black.

    Typical use::

        epd = EPD()
        epd.display(epd.getbuffer(image))
        epd.sleep()
        epd.close()
    """

    # Largest slice handed to spidev in a single writebytes() call.
    _SPI_CHUNK = 4096

    def __init__(self):
        """Configure the GPIO lines and open the SPI device."""
        self.reset_pin = RST_PIN
        self.dc_pin    = DC_PIN
        self.cs_pin    = CS_PIN
        self.busy_pin  = BUSY_PIN
        self.width     = WIDTH
        self.height    = HEIGHT
        # Cleared again by sleep() so the next update re-initialises the panel.
        self._init_done = False

        GPIO.setmode(GPIO.BCM)
        GPIO.setwarnings(False)
        GPIO.setup(self.reset_pin, GPIO.OUT)
        GPIO.setup(self.dc_pin,    GPIO.OUT)
        GPIO.setup(self.cs_pin,    GPIO.OUT)
        GPIO.setup(self.busy_pin,  GPIO.IN)

        self.spi = spidev.SpiDev()
        self.spi.open(0, 0)
        self.spi.max_speed_hz = 4000000
        self.spi.mode = 0b00

    # ------------------------------------------------------------------ low level

    def _reset(self):
        """Pulse the reset line: HIGH 200 ms, LOW 2 ms, HIGH 200 ms."""
        GPIO.output(self.reset_pin, GPIO.HIGH)
        time.sleep(0.2)
        GPIO.output(self.reset_pin, GPIO.LOW)
        time.sleep(0.002)
        GPIO.output(self.reset_pin, GPIO.HIGH)
        time.sleep(0.2)

    def _send_command(self, cmd):
        """Send one command byte (D/C low) framed by chip select."""
        GPIO.output(self.dc_pin, GPIO.LOW)
        GPIO.output(self.cs_pin, GPIO.LOW)
        try:
            self.spi.writebytes([cmd])
        finally:
            # Never leave the controller selected if the SPI write fails.
            GPIO.output(self.cs_pin, GPIO.HIGH)

    def _send_data(self, data):
        """Send data (D/C high) framed by chip select.

        ``data`` is either a single int (one byte) or a sequence of byte
        values, which is sent in slices of at most ``_SPI_CHUNK`` bytes to
        stay within the SPI buffer limit.
        """
        GPIO.output(self.dc_pin, GPIO.HIGH)
        GPIO.output(self.cs_pin, GPIO.LOW)
        try:
            if isinstance(data, int):
                self.spi.writebytes([data])
            else:
                for start in range(0, len(data), self._SPI_CHUNK):
                    self.spi.writebytes(data[start:start + self._SPI_CHUNK])
        finally:
            # Never leave the controller selected if the SPI write fails.
            GPIO.output(self.cs_pin, GPIO.HIGH)

    def _wait_busy(self, timeout_ms=5000):
        """Poll the BUSY pin until it reads HIGH or the timeout expires.

        The pin is sampled every 10 ms and LOW is treated as "busy".  On
        timeout a warning is printed and the wait is abandoned.

        NOTE(review): SSD16xx controllers are usually documented with BUSY
        high while busy, which is the opposite of what this loop waits for --
        confirm the polarity against the panel/HAT actually in use.

        Returns:
            True if the pin read HIGH before the deadline, False on timeout.
        """
        deadline = time.time() + timeout_ms / 1000.0
        while GPIO.input(self.busy_pin) == GPIO.LOW:
            if time.time() > deadline:
                print("WARNING: busy timeout")
                return False
            time.sleep(0.01)
        return True

    # ------------------------------------------------------------------ RAM area

    def _set_ram_area(self, x, y, w, h):
        """Select the RAM window and place the address counters at its origin.

        ``x``/``w`` are in pixels and are converted to byte addresses after
        ``SOURCE_SHIFT`` has been added to ``x``; ``y``/``h`` are in rows.
        """
        x += SOURCE_SHIFT
        x_start = x // 8
        # Inclusive end byte.  Using (x + w - 1) keeps the last, partially
        # filled byte when x + w is not a multiple of 8.
        x_end = (x + w - 1) // 8
        y_end = y + h - 1

        # x increase, y increase (normal scan direction)
        self._send_command(0x11)
        self._send_data(0x03)
        # X start / end (byte addresses)
        self._send_command(0x44)
        self._send_data(x_start)
        self._send_data(x_end)
        # Y start / end (two bytes each, LSB first)
        self._send_command(0x45)
        self._send_data(y % 256)
        self._send_data(y // 256)
        self._send_data(y_end % 256)
        self._send_data(y_end // 256)
        # Set RAM x/y counters
        self._send_command(0x4E)
        self._send_data(x_start)
        self._send_command(0x4F)
        self._send_data(y % 256)
        self._send_data(y // 256)

    def _buffer_size(self):
        """Return the number of bytes in one full frame buffer."""
        return ((self.width + 7) // 8) * self.height

    def _write_frame(self, buf):
        """Write ``buf`` to both the previous (0x26) and current (0x24) RAM."""
        for cmd in (0x26, 0x24):
            self._set_ram_area(0, 0, self.width, self.height)
            self._send_command(cmd)
            self._send_data(buf)

    # ------------------------------------------------------------------ init

    def init(self):
        """Hardware-reset the controller and load the base configuration."""
        self._reset()
        time.sleep(0.01)

        self._send_command(0x12)  # SWRESET
        time.sleep(0.01)

        # Driver output control
        self._send_command(0x01)
        self._send_data((self.height - 1) % 256)
        self._send_data((self.height - 1) // 256)
        self._send_data(0x00)

        # Border waveform
        self._send_command(0x3C)
        self._send_data(0x05)

        # Use built-in temperature sensor
        self._send_command(0x18)
        self._send_data(0x80)

        # Display update control (normal, no bypass)
        self._send_command(0x21)
        self._send_data(0x00)
        self._send_data(0x00)

        self._set_ram_area(0, 0, self.width, self.height)
        self._init_done = True

    # ------------------------------------------------------------------ clear

    def clear(self, color=0xFF):
        """Clear screen. color=0xFF = white, 0x00 = black.

        ``color`` is the byte value every buffer byte is filled with.  The
        panel is initialised first if that has not happened yet.
        """
        if not self._init_done:
            self.init()
        self._write_frame([color] * self._buffer_size())
        self._refresh_full()

    # ------------------------------------------------------------------ display

    def getbuffer(self, image):
        """Convert an image into a frame buffer for display().

        The image is converted to 1-bit, rotated by 90 degrees (expanding the
        canvas) and resized to ``width`` x ``height``.  Pixels whose value is
        0 clear their bit (black); all other bits stay set (white).

        Args:
            image: object offering the PIL ``convert``/``rotate``/``resize``/
                ``load`` methods.

        Returns:
            List of ``((width + 7) // 8) * height`` byte values.
        """
        stride = (self.width + 7) // 8  # bytes per row, last byte may be partial
        # Sized from the stride so indexing stays in range for any width.
        buf = [0xFF] * (stride * self.height)

        image = image.convert('1')
        image = image.rotate(90, expand=True)
        image = image.resize((self.width, self.height))

        pixels = image.load()

        for y in range(self.height):
            row_start = y * stride
            for x in range(self.width):
                if pixels[x, y] == 0:
                    buf[row_start + x // 8] &= ~(0x80 >> (x % 8))

        return buf

    def display(self, buf):
        """Send buffer and do a full refresh.

        ``buf`` is written to both RAM planes so the controller's "previous"
        image matches the new one.
        """
        if not self._init_done:
            self.init()
        self._write_frame(buf)
        self._refresh_full()

    def display_partial(self, buf):
        """Send buffer and do a fast partial refresh.

        Only the current-image RAM (0x24) is rewritten.
        """
        if not self._init_done:
            self.init()
        self._set_ram_area(0, 0, self.width, self.height)
        self._send_command(0x24)
        self._send_data(buf)
        self._refresh_partial()

    # ------------------------------------------------------------------ refresh

    def _refresh_full(self):
        """Run the fast full-update sequence and wait up to 5 s for BUSY."""
        # bypass RED channel, use fast full update with temperature
        self._send_command(0x21)
        self._send_data(0x40)  # bypass RED as 0
        self._send_data(0x00)
        # write temperature register (0x6E = 110 decimal, forces fast LUT)
        self._send_command(0x1A)
        self._send_data(0x6E)
        self._send_data(0x00)
        # load temperature value
        self._send_command(0x22)
        self._send_data(0x91)
        self._send_command(0x20)
        time.sleep(0.002)
        # full update sequence
        self._send_command(0x22)
        self._send_data(0xC7)
        self._send_command(0x20)
        self._wait_busy(5000)

    def _refresh_partial(self):
        """Run the partial-update sequence and wait up to 1 s for BUSY."""
        self._send_command(0x21)
        self._send_data(0x00)
        self._send_data(0x00)
        self._send_command(0x22)
        self._send_data(0xDC)
        self._send_command(0x20)
        self._wait_busy(1000)

    # ------------------------------------------------------------------ power

    def sleep(self):
        """Put the controller into deep sleep and pause for two seconds.

        The init flag is cleared so the next clear()/display() call runs
        init() again, which pulses the reset line (presumably required to
        wake the controller from deep sleep -- confirm against the datasheet).
        """
        self._send_command(0x10)  # deep sleep
        self._send_data(0x01)
        self._init_done = False
        time.sleep(2)

    def close(self):
        """Close the SPI device and release the GPIO channels."""
        self.spi.close()
        GPIO.cleanup()



if __name__ == '__main__':
    # Demo: initialise the panel, blank it, show test.png, then power down.
    import logging
    logging.basicConfig(level=logging.INFO)
    from PIL import Image

    epd = EPD()
    try:
        print("init...")
        epd.init()
        print("clear...")
        epd.clear()
        print("loading image...")
        # The context manager closes the image file once it has been sent.
        with Image.open('test.png') as img:
            print("displaying...")
            epd.display(epd.getbuffer(img))
        print("sleep...")
        epd.sleep()
    finally:
        # Release SPI and GPIO even when one of the steps above raised.
        epd.close()
    print("done!")
